/*
 * Copyright 2002-2022 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.springframework.test.context.aot;

import java.lang.reflect.Method;
import java.util.Map;
import java.util.function.Supplier;

import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;

/**
 * {@code AotTestMappings} provides mappings from test classes to AOT-optimized
 * context initializers.
 *
 * <p>If a test class is not {@linkplain #isSupportedTestClass(Class) supported} in
 * AOT mode, {@link #getContextInitializer(Class)} will return {@code null}.
 *
 * <p>Reflectively accesses {@link #GENERATED_MAPPINGS_CLASS_NAME} generated by
 * the {@link TestContextAotGenerator} to retrieve the mappings generated during
 * AOT processing.
 *
 * @author Sam Brannen
 * @author Stephane Nicoll
 * @since 6.0
 */
public class AotTestMappings {

	// TODO Add support in ClassNameGenerator for supplying a predefined class name.
	// There is a similar issue in Spring Boot where code relies on a generated name.
	// Ideally we would generate a class named: org.springframework.test.context.aot.GeneratedAotTestMappings
	static final String GENERATED_MAPPINGS_CLASS_NAME = AotTestMappings.class.getName() + "__Generated";

	static final String GENERATED_MAPPINGS_METHOD_NAME = "getContextInitializers";

	private final Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers;


	/**
	 * Create a new {@code AotTestMappings} instance backed by the mappings returned
	 * from the static {@link #GENERATED_MAPPINGS_METHOD_NAME} method of the class
	 * named by {@link #GENERATED_MAPPINGS_CLASS_NAME}.
	 * @throws IllegalStateException if the generated class cannot be loaded, does not
	 * declare the expected method, or that method fails or returns {@code null}
	 */
	public AotTestMappings() {
		this(GENERATED_MAPPINGS_CLASS_NAME);
	}

	/**
	 * Create a new {@code AotTestMappings} instance backed by the mappings returned
	 * from the static {@link #GENERATED_MAPPINGS_METHOD_NAME} method of the given class.
	 * @param initializerClassName the fully-qualified name of the class to load
	 * @throws IllegalStateException if the class cannot be loaded, does not declare
	 * the expected method, or that method fails or returns {@code null}
	 */
	AotTestMappings(String initializerClassName) {
		this(loadContextInitializersMap(initializerClassName));
	}

	/**
	 * Create a new {@code AotTestMappings} instance backed by the supplied mappings.
	 * @param contextInitializers mappings from fully-qualified test class names to
	 * suppliers of their AOT context initializers; must not be {@code null}
	 * @throws IllegalArgumentException if {@code contextInitializers} is {@code null}
	 */
	AotTestMappings(Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers) {
		// Fail fast: a null map would otherwise surface later as an opaque NPE
		// from isSupportedTestClass() or getContextInitializer().
		Assert.notNull(contextInitializers, "'contextInitializers' must not be null");
		this.contextInitializers = contextInitializers;
	}


	/**
	 * Determine if the specified test class has an AOT-optimized application context
	 * initializer.
	 * <p>If this method returns {@code true}, {@link #getContextInitializer(Class)}
	 * should not return {@code null}.
	 * @param testClass the test class to check; looked up by its fully-qualified name
	 * @return {@code true} if a mapping exists for the test class
	 */
	public boolean isSupportedTestClass(Class<?> testClass) {
		return this.contextInitializers.containsKey(testClass.getName());
	}

	/**
	 * Get the AOT {@link ApplicationContextInitializer} for the specified test class.
	 * <p>A new initializer is obtained from the registered supplier on each call.
	 * @param testClass the test class; looked up by its fully-qualified name
	 * @return the AOT context initializer, or {@code null} if there is no AOT context
	 * initializer for the specified test class
	 * @see #isSupportedTestClass(Class)
	 */
	@Nullable
	public ApplicationContextInitializer<ConfigurableApplicationContext> getContextInitializer(Class<?> testClass) {
		Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>> supplier =
				this.contextInitializers.get(testClass.getName());
		return (supplier != null ? supplier.get() : null);
	}


	/**
	 * Load the given class and invoke its static {@link #GENERATED_MAPPINGS_METHOD_NAME}
	 * method to obtain the test-class-to-initializer mappings.
	 * @param className the fully-qualified name of the generated mappings class
	 * @return the non-null mappings returned by the generated method
	 * @throws IllegalStateException if the class cannot be loaded, the method is
	 * missing, the invocation fails, or the method returns {@code null}
	 */
	@SuppressWarnings("unchecked")
	private static Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>>
			loadContextInitializersMap(String className) {

		String methodName = GENERATED_MAPPINGS_METHOD_NAME;

		try {
			Class<?> clazz = ClassUtils.forName(className, null);
			Method method = ReflectionUtils.findMethod(clazz, methodName);
			Assert.state(method != null, () -> "No %s() method found in %s".formatted(methodName, clazz.getName()));
			// The generated method is static, hence the null target.
			Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> initializers =
					(Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>>)
							ReflectionUtils.invokeMethod(method, null);
			// Reject a null result here rather than letting it surface later as an NPE.
			Assert.state(initializers != null,
					() -> "%s() method in %s returned null".formatted(methodName, clazz.getName()));
			return initializers;
		}
		catch (IllegalStateException ex) {
			// Already descriptive (e.g. from Assert.state above): propagate as-is
			// instead of wrapping it in the generic failure below.
			throw ex;
		}
		catch (Exception ex) {
			throw new IllegalStateException("Failed to invoke %s() method in %s".formatted(methodName, className), ex);
		}
	}

}
